{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "import json\n",
    "from IPython.display import display\n",
    "from hamilton import driver\n",
    "\n",
    "import __init__ as xgboost_optuna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load config examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'task': 'classification'},\n",
      " {'task': 'regression'}]\n"
     ]
    }
   ],
   "source": [
    "def read_jsonl(file_path: str) -> list:\n",
    "    \"\"\"Read a JSON Lines file into a list of parsed records.\n",
    "\n",
    "    Each non-blank line is parsed independently with json.loads.\n",
    "    Whitespace-only lines are skipped so a stray blank line does not\n",
    "    raise a JSONDecodeError. The file is decoded as UTF-8 (the encoding\n",
    "    JSON text must use) instead of the platform's locale default.\n",
    "    \"\"\"\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "valid_configs = read_jsonl(\"valid_configs.jsonl\")\n",
    "pprint(valid_configs, width=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Note: Hamilton collects completely anonymous data about usage. This will help us improve Hamilton over time. See https://github.com/apache/hamilton#usage-analytics--data-privacy for details.\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.43.0 (0)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"1782pt\" height=\"860pt\"\n",
       " viewBox=\"0.00 0.00 1781.50 859.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 855.5)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-855.5 1777.5,-855.5 1777.5,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1245,-767.5 1245,-843.5 1496,-843.5 1496,-767.5 1245,-767.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"1370.5\" y=\"-828.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- base_model -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>base_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1296.5,-542.5C1296.5,-542.5 1205.5,-542.5 1205.5,-542.5 1199.5,-542.5 1193.5,-536.5 1193.5,-530.5 1193.5,-530.5 1193.5,-490.5 1193.5,-490.5 1193.5,-484.5 1199.5,-478.5 1205.5,-478.5 1205.5,-478.5 1296.5,-478.5 1296.5,-478.5 1302.5,-478.5 1308.5,-484.5 1308.5,-490.5 1308.5,-490.5 1308.5,-530.5 1308.5,-530.5 1308.5,-536.5 1302.5,-542.5 1296.5,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"1204.5\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">base_model</text>\n",
       "<text text-anchor=\"start\" x=\"1223.5\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Callable</text>\n",
       "</g>\n",
       "<!-- hyperparameter_search -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>hyperparameter_search</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1024,-438C1024,-438 838,-438 838,-438 832,-438 826,-432 826,-426 826,-426 826,-386 826,-386 826,-380 832,-374 838,-374 838,-374 1024,-374 1024,-374 1030,-374 1036,-380 1036,-386 1036,-386 1036,-426 1036,-426 1036,-432 1030,-438 1024,-438\"/>\n",
       "<text text-anchor=\"start\" x=\"837\" y=\"-416.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">hyperparameter_search</text>\n",
       "<text text-anchor=\"start\" x=\"918\" y=\"-388.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- base_model&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>base_model&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1207.05,-478.32C1199.89,-474.1 1192.4,-470.15 1185,-467 1141.19,-448.37 1090.57,-434.93 1046.09,-425.55\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1046.62,-422.09 1036.12,-423.5 1045.2,-428.94 1046.62,-422.09\"/>\n",
       "</g>\n",
       "<!-- best_model -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>best_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1167,-250C1167,-250 1079,-250 1079,-250 1073,-250 1067,-244 1067,-238 1067,-238 1067,-198 1067,-198 1067,-192 1073,-186 1079,-186 1079,-186 1167,-186 1167,-186 1173,-186 1179,-192 1179,-198 1179,-198 1179,-238 1179,-238 1179,-244 1173,-250 1167,-250\"/>\n",
       "<text text-anchor=\"start\" x=\"1078\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_model</text>\n",
       "<text text-anchor=\"start\" x=\"1087\" y=\"-200.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">XGBModel</text>\n",
       "</g>\n",
       "<!-- base_model&#45;&gt;best_model -->\n",
       "<g id=\"edge12\" class=\"edge\">\n",
       "<title>base_model&#45;&gt;best_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1257.12,-478.23C1264.58,-430.96 1271.74,-340.33 1232,-279 1221.59,-262.94 1205.34,-250.81 1188.52,-241.81\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1189.67,-238.47 1179.16,-237.13 1186.54,-244.73 1189.67,-238.47\"/>\n",
       "</g>\n",
       "<!-- optuna_distributions -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>optuna_distributions</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M868,-542.5C868,-542.5 708,-542.5 708,-542.5 702,-542.5 696,-536.5 696,-530.5 696,-530.5 696,-490.5 696,-490.5 696,-484.5 702,-478.5 708,-478.5 708,-478.5 868,-478.5 868,-478.5 874,-478.5 880,-484.5 880,-490.5 880,-490.5 880,-530.5 880,-530.5 880,-536.5 874,-542.5 868,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"707\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">optuna_distributions</text>\n",
       "<text text-anchor=\"start\" x=\"775\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- optuna_distributions&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>optuna_distributions&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M831.29,-478.47C846.27,-467.73 863.22,-455.58 878.84,-444.39\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"881.27,-446.95 887.36,-438.28 877.19,-441.26 881.27,-446.95\"/>\n",
       "</g>\n",
       "<!-- best_hyperparameters -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>best_hyperparameters</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1211,-344C1211,-344 1035,-344 1035,-344 1029,-344 1023,-338 1023,-332 1023,-332 1023,-292 1023,-292 1023,-286 1029,-280 1035,-280 1035,-280 1211,-280 1211,-280 1217,-280 1223,-286 1223,-292 1223,-292 1223,-332 1223,-332 1223,-338 1217,-344 1211,-344\"/>\n",
       "<text text-anchor=\"start\" x=\"1034\" y=\"-322.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">best_hyperparameters</text>\n",
       "<text text-anchor=\"start\" x=\"1110\" y=\"-294.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- hyperparameter_search&#45;&gt;best_hyperparameters -->\n",
       "<g id=\"edge16\" class=\"edge\">\n",
       "<title>hyperparameter_search&#45;&gt;best_hyperparameters</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M995.73,-373.98C1012.89,-365.76 1031.53,-356.83 1049.13,-348.4\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1050.71,-351.52 1058.22,-344.04 1047.69,-345.21 1050.71,-351.52\"/>\n",
       "</g>\n",
       "<!-- study_results -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>study_results</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M560.5,-344C560.5,-344 457.5,-344 457.5,-344 451.5,-344 445.5,-338 445.5,-332 445.5,-332 445.5,-292 445.5,-292 445.5,-286 451.5,-280 457.5,-280 457.5,-280 560.5,-280 560.5,-280 566.5,-280 572.5,-286 572.5,-292 572.5,-292 572.5,-332 572.5,-332 572.5,-338 566.5,-344 560.5,-344\"/>\n",
       "<text text-anchor=\"start\" x=\"456.5\" y=\"-322.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results</text>\n",
       "<text text-anchor=\"start\" x=\"488.5\" y=\"-294.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n",
       "</g>\n",
       "<!-- hyperparameter_search&#45;&gt;study_results -->\n",
       "<g id=\"edge17\" class=\"edge\">\n",
       "<title>hyperparameter_search&#45;&gt;study_results</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M825.63,-393.66C758.1,-384.81 668.81,-370.15 582.46,-344.88\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"583.35,-341.5 572.76,-342 581.35,-348.2 583.35,-341.5\"/>\n",
       "</g>\n",
       "<!-- higher_is_better -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>higher_is_better</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M804.5,-690C804.5,-690 679.5,-690 679.5,-690 673.5,-690 667.5,-684 667.5,-678 667.5,-678 667.5,-638 667.5,-638 667.5,-632 673.5,-626 679.5,-626 679.5,-626 804.5,-626 804.5,-626 810.5,-626 816.5,-632 816.5,-638 816.5,-638 816.5,-678 816.5,-678 816.5,-684 810.5,-690 804.5,-690\"/>\n",
       "<text text-anchor=\"start\" x=\"678.5\" y=\"-668.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">higher_is_better</text>\n",
       "<text text-anchor=\"start\" x=\"727\" y=\"-640.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">bool</text>\n",
       "</g>\n",
       "<!-- study -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>study</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M952,-542.5C952,-542.5 910,-542.5 910,-542.5 904,-542.5 898,-536.5 898,-530.5 898,-530.5 898,-490.5 898,-490.5 898,-484.5 904,-478.5 910,-478.5 910,-478.5 952,-478.5 952,-478.5 958,-478.5 964,-484.5 964,-490.5 964,-490.5 964,-530.5 964,-530.5 964,-536.5 958,-542.5 952,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"909\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study</text>\n",
       "<text text-anchor=\"start\" x=\"910.5\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Study</text>\n",
       "</g>\n",
       "<!-- higher_is_better&#45;&gt;study -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>higher_is_better&#45;&gt;study</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M772.48,-625.77C787.43,-611.55 806.23,-595.21 825,-583 851.18,-565.97 863.37,-571.85 889,-554 891.09,-552.55 893.16,-550.98 895.2,-549.34\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"897.6,-551.89 902.88,-542.71 893.02,-546.6 897.6,-551.89\"/>\n",
       "</g>\n",
       "<!-- study&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>study&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M931,-478.47C931,-469.02 931,-458.49 931,-448.47\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"934.5,-448.28 931,-438.28 927.5,-448.28 934.5,-448.28\"/>\n",
       "</g>\n",
       "<!-- y_test_pred -->\n",
       "<g id=\"node13\" class=\"node\">\n",
       "<title>y_test_pred</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1167.5,-157C1167.5,-157 1078.5,-157 1078.5,-157 1072.5,-157 1066.5,-151 1066.5,-145 1066.5,-145 1066.5,-105 1066.5,-105 1066.5,-99 1072.5,-93 1078.5,-93 1078.5,-93 1167.5,-93 1167.5,-93 1173.5,-93 1179.5,-99 1179.5,-105 1179.5,-105 1179.5,-145 1179.5,-145 1179.5,-151 1173.5,-157 1167.5,-157\"/>\n",
       "<text text-anchor=\"start\" x=\"1077.5\" y=\"-135.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">y_test_pred</text>\n",
       "<text text-anchor=\"start\" x=\"1095.5\" y=\"-107.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">ndarray</text>\n",
       "</g>\n",
       "<!-- best_model&#45;&gt;y_test_pred -->\n",
       "<g id=\"edge20\" class=\"edge\">\n",
       "<title>best_model&#45;&gt;y_test_pred</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1123,-185.94C1123,-180 1123,-173.7 1123,-167.49\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1126.5,-167.23 1123,-157.23 1119.5,-167.23 1126.5,-167.23\"/>\n",
       "</g>\n",
       "<!-- task -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>task</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1229,-819 1181,-819 1181,-769 1235,-769 1235,-813 1229,-819\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"1229,-819 1229,-813 \"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"1235,-813 1229,-813 \"/>\n",
       "<text text-anchor=\"start\" x=\"1191\" y=\"-804.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">task</text>\n",
       "<text text-anchor=\"start\" x=\"1194\" y=\"-776.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n",
       "</g>\n",
       "<!-- best_hyperparameters&#45;&gt;best_model -->\n",
       "<g id=\"edge14\" class=\"edge\">\n",
       "<title>best_hyperparameters&#45;&gt;best_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1123,-279.85C1123,-273.56 1123,-266.85 1123,-260.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1126.5,-260.11 1123,-250.11 1119.5,-260.11 1126.5,-260.11\"/>\n",
       "</g>\n",
       "<!-- study_results_df -->\n",
       "<g id=\"node15\" class=\"node\">\n",
       "<title>study_results_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M571.5,-250C571.5,-250 446.5,-250 446.5,-250 440.5,-250 434.5,-244 434.5,-238 434.5,-238 434.5,-198 434.5,-198 434.5,-192 440.5,-186 446.5,-186 446.5,-186 571.5,-186 571.5,-186 577.5,-186 583.5,-192 583.5,-198 583.5,-198 583.5,-238 583.5,-238 583.5,-244 577.5,-250 571.5,-250\"/>\n",
       "<text text-anchor=\"start\" x=\"445.5\" y=\"-228.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">study_results_df</text>\n",
       "<text text-anchor=\"start\" x=\"470.5\" y=\"-200.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- study_results&#45;&gt;study_results_df -->\n",
       "<g id=\"edge23\" class=\"edge\">\n",
       "<title>study_results&#45;&gt;study_results_df</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M509,-279.85C509,-273.56 509,-266.85 509,-260.27\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"512.5,-260.11 509,-250.11 505.5,-260.11 512.5,-260.11\"/>\n",
       "</g>\n",
       "<!-- scorer -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>scorer</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1151,-826C1151,-826 1103,-826 1103,-826 1097,-826 1091,-820 1091,-814 1091,-814 1091,-774 1091,-774 1091,-768 1097,-762 1103,-762 1103,-762 1151,-762 1151,-762 1157,-762 1163,-768 1163,-774 1163,-774 1163,-814 1163,-814 1163,-820 1157,-826 1151,-826\"/>\n",
       "<text text-anchor=\"start\" x=\"1102\" y=\"-804.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scorer</text>\n",
       "<text text-anchor=\"start\" x=\"1114\" y=\"-776.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- scorer&#45;&gt;higher_is_better -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>scorer&#45;&gt;higher_is_better</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1090.62,-791.89C1031.81,-788.79 913.52,-777.09 825,-733 807.36,-724.21 790.59,-710.66 776.8,-697.55\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"779.04,-694.85 769.46,-690.34 774.14,-699.85 779.04,-694.85\"/>\n",
       "</g>\n",
       "<!-- scoring_func -->\n",
       "<g id=\"node14\" class=\"node\">\n",
       "<title>scoring_func</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1435,-542.5C1435,-542.5 1339,-542.5 1339,-542.5 1333,-542.5 1327,-536.5 1327,-530.5 1327,-530.5 1327,-490.5 1327,-490.5 1327,-484.5 1333,-478.5 1339,-478.5 1339,-478.5 1435,-478.5 1435,-478.5 1441,-478.5 1447,-484.5 1447,-490.5 1447,-490.5 1447,-530.5 1447,-530.5 1447,-536.5 1441,-542.5 1435,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"1338\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">scoring_func</text>\n",
       "<text text-anchor=\"start\" x=\"1358\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- scorer&#45;&gt;scoring_func -->\n",
       "<g id=\"edge22\" class=\"edge\">\n",
       "<title>scorer&#45;&gt;scoring_func</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1163.25,-765.83C1166.12,-764.37 1169.06,-763.06 1172,-762 1239.36,-737.66 1444.69,-785.88 1493,-733 1537.97,-683.78 1524.3,-641.86 1493,-583 1484.47,-566.96 1470.56,-553.87 1455.58,-543.48\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1457.46,-540.53 1447.18,-537.98 1453.63,-546.38 1457.46,-540.53\"/>\n",
       "</g>\n",
       "<!-- cross_validation_folds -->\n",
       "<g id=\"node11\" class=\"node\">\n",
       "<title>cross_validation_folds</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1164,-542.5C1164,-542.5 994,-542.5 994,-542.5 988,-542.5 982,-536.5 982,-530.5 982,-530.5 982,-490.5 982,-490.5 982,-484.5 988,-478.5 994,-478.5 994,-478.5 1164,-478.5 1164,-478.5 1170,-478.5 1176,-484.5 1176,-490.5 1176,-490.5 1176,-530.5 1176,-530.5 1176,-536.5 1170,-542.5 1164,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"993\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">cross_validation_folds</text>\n",
       "<text text-anchor=\"start\" x=\"1044.5\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Sequence</text>\n",
       "</g>\n",
       "<!-- cross_validation_folds&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>cross_validation_folds&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1034.19,-478.47C1018.47,-467.58 1000.66,-455.24 984.31,-443.92\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"986.12,-440.92 975.91,-438.1 982.14,-446.67 986.12,-440.92\"/>\n",
       "</g>\n",
       "<!-- model_config -->\n",
       "<g id=\"node12\" class=\"node\">\n",
       "<title>model_config</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M176.5,-542.5C176.5,-542.5 75.5,-542.5 75.5,-542.5 69.5,-542.5 63.5,-536.5 63.5,-530.5 63.5,-530.5 63.5,-490.5 63.5,-490.5 63.5,-484.5 69.5,-478.5 75.5,-478.5 75.5,-478.5 176.5,-478.5 176.5,-478.5 182.5,-478.5 188.5,-484.5 188.5,-490.5 188.5,-490.5 188.5,-530.5 188.5,-530.5 188.5,-536.5 182.5,-542.5 176.5,-542.5\"/>\n",
       "<text text-anchor=\"start\" x=\"74.5\" y=\"-521.3\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">model_config</text>\n",
       "<text text-anchor=\"start\" x=\"113\" y=\"-493.3\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- model_config&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>model_config&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M171.13,-478.38C179.43,-473.87 188.25,-469.81 197,-467 308.96,-431.08 638.93,-415.8 815.85,-410.05\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"816.1,-413.54 825.98,-409.72 815.87,-406.55 816.1,-413.54\"/>\n",
       "</g>\n",
       "<!-- model_config&#45;&gt;best_model -->\n",
       "<g id=\"edge13\" class=\"edge\">\n",
       "<title>model_config&#45;&gt;best_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M154.86,-478.33C205.73,-425.84 317.26,-321.4 436,-279 563.39,-233.51 907.84,-281.42 1056.82,-249.75\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1057.83,-253.11 1066.78,-247.43 1056.25,-246.29 1057.83,-253.11\"/>\n",
       "</g>\n",
       "<!-- test_score -->\n",
       "<g id=\"node16\" class=\"node\">\n",
       "<title>test_score</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1378,-64C1378,-64 1298,-64 1298,-64 1292,-64 1286,-58 1286,-52 1286,-52 1286,-12 1286,-12 1286,-6 1292,0 1298,0 1298,0 1378,0 1378,0 1384,0 1390,-6 1390,-12 1390,-12 1390,-52 1390,-52 1390,-58 1384,-64 1378,-64\"/>\n",
       "<text text-anchor=\"start\" x=\"1297\" y=\"-42.8\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">test_score</text>\n",
       "<text text-anchor=\"start\" x=\"1324\" y=\"-14.8\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">Any</text>\n",
       "</g>\n",
       "<!-- y_test_pred&#45;&gt;test_score -->\n",
       "<g id=\"edge24\" class=\"edge\">\n",
       "<title>y_test_pred&#45;&gt;test_score</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1179.77,-99.97C1209.54,-87.37 1246.03,-71.93 1276.47,-59.04\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1277.87,-62.25 1285.71,-55.13 1275.14,-55.81 1277.87,-62.25\"/>\n",
       "</g>\n",
       "<!-- scoring_func&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>scoring_func&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1342.28,-478.47C1334.45,-474.06 1326.19,-470.02 1318,-467 1230.23,-434.67 1125.03,-419.76 1046.48,-412.88\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1046.33,-409.36 1036.07,-412.01 1045.74,-416.33 1046.33,-409.36\"/>\n",
       "</g>\n",
       "<!-- scoring_func&#45;&gt;test_score -->\n",
       "<g id=\"edge25\" class=\"edge\">\n",
       "<title>scoring_func&#45;&gt;test_score</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1374.25,-478.45C1359.74,-440.21 1338,-372.95 1338,-313 1338,-313 1338,-313 1338,-217 1338,-168.15 1338,-112.02 1338,-74.72\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1341.5,-74.25 1338,-64.25 1334.5,-74.25 1341.5,-74.25\"/>\n",
       "</g>\n",
       "<!-- _optuna_distributions_inputs -->\n",
       "<g id=\"node17\" class=\"node\">\n",
       "<title>_optuna_distributions_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"649,-680.5 347,-680.5 347,-635.5 649,-635.5 649,-680.5\"/>\n",
       "<text text-anchor=\"start\" x=\"362\" y=\"-653.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">optuna_distributions_override</text>\n",
       "<text text-anchor=\"start\" x=\"575\" y=\"-653.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "</g>\n",
       "<!-- _optuna_distributions_inputs&#45;&gt;optuna_distributions -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>_optuna_distributions_inputs&#45;&gt;optuna_distributions</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M541.1,-635.37C587.51,-612.09 662.08,-574.68 717.14,-547.05\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"718.77,-550.15 726.14,-542.54 715.63,-543.89 718.77,-550.15\"/>\n",
       "</g>\n",
       "<!-- _hyperparameter_search_inputs -->\n",
       "<g id=\"node18\" class=\"node\">\n",
       "<title>_hyperparameter_search_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"677.5,-554 206.5,-554 206.5,-467 677.5,-467 677.5,-554\"/>\n",
       "<text text-anchor=\"start\" x=\"250.5\" y=\"-527.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n",
       "<text text-anchor=\"start\" x=\"334\" y=\"-527.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "<text text-anchor=\"start\" x=\"251\" y=\"-506.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n",
       "<text text-anchor=\"start\" x=\"334\" y=\"-506.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "<text text-anchor=\"start\" x=\"222\" y=\"-485.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_optuna_trials</text>\n",
       "<text text-anchor=\"start\" x=\"489\" y=\"-485.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _hyperparameter_search_inputs&#45;&gt;hyperparameter_search -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>_hyperparameter_search_inputs&#45;&gt;hyperparameter_search</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M644.78,-467C703.12,-454.77 764.73,-441.85 815.78,-431.15\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"816.78,-434.52 825.84,-429.04 815.34,-427.67 816.78,-434.52\"/>\n",
       "</g>\n",
       "<!-- _study_inputs -->\n",
       "<g id=\"node19\" class=\"node\">\n",
       "<title>_study_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1028,-722.5 834,-722.5 834,-593.5 1028,-593.5 1028,-722.5\"/>\n",
       "<text text-anchor=\"start\" x=\"852\" y=\"-695.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">load_if_exists</text>\n",
       "<text text-anchor=\"start\" x=\"968.5\" y=\"-695.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n",
       "<text text-anchor=\"start\" x=\"874.5\" y=\"-674.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">pruner</text>\n",
       "<text text-anchor=\"start\" x=\"954\" y=\"-674.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "<text text-anchor=\"start\" x=\"856\" y=\"-653.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_name</text>\n",
       "<text text-anchor=\"start\" x=\"954\" y=\"-653.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "<text text-anchor=\"start\" x=\"849\" y=\"-632.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">study_storage</text>\n",
       "<text text-anchor=\"start\" x=\"954\" y=\"-632.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "<text text-anchor=\"start\" x=\"870\" y=\"-611.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">sampler</text>\n",
       "<text text-anchor=\"start\" x=\"954\" y=\"-611.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "</g>\n",
       "<!-- _study_inputs&#45;&gt;study -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>_study_inputs&#45;&gt;study</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M931,-593.18C931,-579.46 931,-565.31 931,-552.72\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"934.5,-552.56 931,-542.56 927.5,-552.56 934.5,-552.56\"/>\n",
       "</g>\n",
       "<!-- _best_model_inputs -->\n",
       "<g id=\"node20\" class=\"node\">\n",
       "<title>_best_model_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1005,-345 591,-345 591,-279 1005,-279 1005,-345\"/>\n",
       "<text text-anchor=\"start\" x=\"606\" y=\"-318.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n",
       "<text text-anchor=\"start\" x=\"661\" y=\"-318.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "<text text-anchor=\"start\" x=\"606.5\" y=\"-297.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n",
       "<text text-anchor=\"start\" x=\"661\" y=\"-297.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "</g>\n",
       "<!-- _best_model_inputs&#45;&gt;best_model -->\n",
       "<g id=\"edge15\" class=\"edge\">\n",
       "<title>_best_model_inputs&#45;&gt;best_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M958.71,-278.99C989.19,-271.47 1020.75,-262.68 1057.08,-250.1\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1058.67,-253.26 1066.94,-246.64 1056.35,-246.65 1058.67,-253.26\"/>\n",
       "</g>\n",
       "<!-- _cross_validation_folds_inputs -->\n",
       "<g id=\"node21\" class=\"node\">\n",
       "<title>_cross_validation_folds_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1484,-733 1046,-733 1046,-583 1484,-583 1484,-733\"/>\n",
       "<text text-anchor=\"start\" x=\"1073\" y=\"-706.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_train</text>\n",
       "<text text-anchor=\"start\" x=\"1140\" y=\"-706.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "<text text-anchor=\"start\" x=\"1073\" y=\"-685.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">stratify</text>\n",
       "<text text-anchor=\"start\" x=\"1289.5\" y=\"-685.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n",
       "<text text-anchor=\"start\" x=\"1073.5\" y=\"-664.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">shuffle</text>\n",
       "<text text-anchor=\"start\" x=\"1289.5\" y=\"-664.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">bool</text>\n",
       "<text text-anchor=\"start\" x=\"1080.5\" y=\"-643.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n",
       "<text text-anchor=\"start\" x=\"1295\" y=\"-643.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"1073.5\" y=\"-622.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_train</text>\n",
       "<text text-anchor=\"start\" x=\"1140\" y=\"-622.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "<text text-anchor=\"start\" x=\"1061\" y=\"-601.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">n_cv_folds</text>\n",
       "<text text-anchor=\"start\" x=\"1295\" y=\"-601.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "</g>\n",
       "<!-- _cross_validation_folds_inputs&#45;&gt;cross_validation_folds -->\n",
       "<g id=\"edge18\" class=\"edge\">\n",
       "<title>_cross_validation_folds_inputs&#45;&gt;cross_validation_folds</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1170.37,-582.97C1155.35,-571.23 1140.4,-559.53 1127.06,-549.1\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1129.02,-546.18 1118.99,-542.78 1124.71,-551.7 1129.02,-546.18\"/>\n",
       "</g>\n",
       "<!-- _model_config_inputs -->\n",
       "<g id=\"node22\" class=\"node\">\n",
       "<title>_model_config_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"252,-691 0,-691 0,-625 252,-625 252,-691\"/>\n",
       "<text text-anchor=\"start\" x=\"76.5\" y=\"-664.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">seed</text>\n",
       "<text text-anchor=\"start\" x=\"198\" y=\"-664.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">int</text>\n",
       "<text text-anchor=\"start\" x=\"15\" y=\"-643.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">model_config_override</text>\n",
       "<text text-anchor=\"start\" x=\"178\" y=\"-643.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Optional</text>\n",
       "</g>\n",
       "<!-- _model_config_inputs&#45;&gt;model_config -->\n",
       "<g id=\"edge19\" class=\"edge\">\n",
       "<title>_model_config_inputs&#45;&gt;model_config</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M126,-624.71C126,-603.51 126,-575.59 126,-552.78\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"129.5,-552.7 126,-542.7 122.5,-552.7 129.5,-552.7\"/>\n",
       "</g>\n",
       "<!-- _y_test_pred_inputs -->\n",
       "<g id=\"node23\" class=\"node\">\n",
       "<title>_y_test_pred_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1049,-240.5 641,-240.5 641,-195.5 1049,-195.5 1049,-240.5\"/>\n",
       "<text text-anchor=\"start\" x=\"656\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">X_test</text>\n",
       "<text text-anchor=\"start\" x=\"705\" y=\"-213.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "</g>\n",
       "<!-- _y_test_pred_inputs&#45;&gt;y_test_pred -->\n",
       "<g id=\"edge21\" class=\"edge\">\n",
       "<title>_y_test_pred_inputs&#45;&gt;y_test_pred</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M910.88,-195.43C954.77,-181.07 1012.25,-162.25 1056.3,-147.83\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1057.64,-151.08 1066.06,-144.64 1055.46,-144.42 1057.64,-151.08\"/>\n",
       "</g>\n",
       "<!-- _test_score_inputs -->\n",
       "<g id=\"node24\" class=\"node\">\n",
       "<title>_test_score_inputs</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1773.5,-147.5 1366.5,-147.5 1366.5,-102.5 1773.5,-102.5 1773.5,-147.5\"/>\n",
       "<text text-anchor=\"start\" x=\"1382\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">y_test</text>\n",
       "<text text-anchor=\"start\" x=\"1430\" y=\"-120.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">numpy.ndarray | pandas.core.frame.DataFrame</text>\n",
       "</g>\n",
       "<!-- _test_score_inputs&#45;&gt;test_score -->\n",
       "<g id=\"edge26\" class=\"edge\">\n",
       "<title>_test_score_inputs&#45;&gt;test_score</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1515.02,-102.43C1480.33,-88.83 1435.47,-71.23 1399.59,-57.16\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1400.78,-53.87 1390.2,-53.47 1398.23,-60.38 1400.78,-53.87\"/>\n",
       "</g>\n",
       "<!-- config -->\n",
       "<g id=\"node25\" class=\"node\">\n",
       "<title>config</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1481.5,-812 1428.5,-812 1428.5,-776 1487.5,-776 1487.5,-806 1481.5,-812\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"1481.5,-812 1481.5,-806 \"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"1487.5,-806 1481.5,-806 \"/>\n",
       "<text text-anchor=\"middle\" x=\"1458\" y=\"-790.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">config</text>\n",
       "</g>\n",
       "<!-- input -->\n",
       "<g id=\"node26\" class=\"node\">\n",
       "<title>input</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" points=\"1410.5,-812.5 1351.5,-812.5 1351.5,-775.5 1410.5,-775.5 1410.5,-812.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"1381\" y=\"-790.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">input</text>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node27\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M1321,-812.5C1321,-812.5 1265,-812.5 1265,-812.5 1259,-812.5 1253,-806.5 1253,-800.5 1253,-800.5 1253,-787.5 1253,-787.5 1253,-781.5 1259,-775.5 1265,-775.5 1265,-775.5 1321,-775.5 1321,-775.5 1327,-775.5 1333,-781.5 1333,-787.5 1333,-787.5 1333,-800.5 1333,-800.5 1333,-806.5 1327,-812.5 1321,-812.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"1293\" y=\"-790.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7f5be7bae260>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "config = dict(\n",
    "    task=\"classification\"\n",
    ")\n",
    "\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_modules(xgboost_optuna)\n",
    "    .with_config(config)\n",
    "    .build()\n",
    ")\n",
    "\n",
    "display(dr.display_all_functions(None, orient=\"TB\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_breast_cancer\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Load the breast cancer dataset (binary classification example, matching task=\"classification\")\n",
    "data = load_breast_cancer()\n",
    "X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[I 2023-10-27 19:07:24,676] A new study created in memory with name: no-name-f5a6ea86-4b6c-4805-b5b5-d3db18771b86\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['base_model',\n",
      " 'best_model',\n",
      " 'cross_validation_folds',\n",
      " 'hyperparameter_search',\n",
      " 'study_results',\n",
      " 'best_hyperparameters',\n",
      " 'model_config',\n",
      " 'optuna_distributions',\n",
      " 'scorer',\n",
      " 'scoring_func',\n",
      " 'higher_is_better',\n",
      " 'study',\n",
      " 'study_results_df',\n",
      " 'test_score',\n",
      " 'y_test_pred']\n"
     ]
    }
   ],
   "source": [
    "final_vars = [v for v in dr.graph.get_nodes() if v._tags.get(\"module\") == \"__init__\"]\n",
    "\n",
    "inputs = dict(\n",
    "    X_train=X_train,\n",
    "    y_train=y_train,\n",
    "    X_test=X_test,\n",
    "    y_test=y_test,\n",
    ")\n",
    "\n",
    "overrides = dict()\n",
    "\n",
    "res = dr.execute(\n",
    "    final_vars=final_vars,\n",
    "    inputs=inputs,\n",
    "    overrides=overrides\n",
    ")\n",
    "\n",
    "pprint(list(res.keys()), width=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>number</th>\n",
       "      <th>value</th>\n",
       "      <th>datetime_start</th>\n",
       "      <th>datetime_complete</th>\n",
       "      <th>duration</th>\n",
       "      <th>params_colsample_bytree</th>\n",
       "      <th>params_gamma</th>\n",
       "      <th>params_learning_rate</th>\n",
       "      <th>params_max_delta_step</th>\n",
       "      <th>params_max_depth</th>\n",
       "      <th>params_min_child_weight</th>\n",
       "      <th>params_n_estimators</th>\n",
       "      <th>state</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.938451</td>\n",
       "      <td>2023-10-27 19:07:24.677858</td>\n",
       "      <td>2023-10-27 19:07:25.159863</td>\n",
       "      <td>0 days 00:00:00.482005</td>\n",
       "      <td>0.662668</td>\n",
       "      <td>18.162984</td>\n",
       "      <td>0.046331</td>\n",
       "      <td>6</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>700</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.966989</td>\n",
       "      <td>2023-10-27 19:07:25.160009</td>\n",
       "      <td>2023-10-27 19:07:25.289743</td>\n",
       "      <td>0 days 00:00:00.129734</td>\n",
       "      <td>0.781939</td>\n",
       "      <td>0.281130</td>\n",
       "      <td>0.084685</td>\n",
       "      <td>8</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>250</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.971390</td>\n",
       "      <td>2023-10-27 19:07:25.289875</td>\n",
       "      <td>2023-10-27 19:07:25.483036</td>\n",
       "      <td>0 days 00:00:00.193161</td>\n",
       "      <td>0.752015</td>\n",
       "      <td>0.047018</td>\n",
       "      <td>0.090246</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>700</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.958188</td>\n",
       "      <td>2023-10-27 19:07:25.483285</td>\n",
       "      <td>2023-10-27 19:07:26.049349</td>\n",
       "      <td>0 days 00:00:00.566064</td>\n",
       "      <td>0.952319</td>\n",
       "      <td>0.357577</td>\n",
       "      <td>0.014098</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>550</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.958203</td>\n",
       "      <td>2023-10-27 19:07:26.049490</td>\n",
       "      <td>2023-10-27 19:07:26.284952</td>\n",
       "      <td>0 days 00:00:00.235462</td>\n",
       "      <td>0.566256</td>\n",
       "      <td>2.996323</td>\n",
       "      <td>0.026967</td>\n",
       "      <td>4</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>400</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>0.969197</td>\n",
       "      <td>2023-10-27 19:07:26.285088</td>\n",
       "      <td>2023-10-27 19:07:26.600701</td>\n",
       "      <td>0 days 00:00:00.315613</td>\n",
       "      <td>0.538976</td>\n",
       "      <td>0.066454</td>\n",
       "      <td>0.032400</td>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>4</td>\n",
       "      <td>700</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>0.947223</td>\n",
       "      <td>2023-10-27 19:07:26.600836</td>\n",
       "      <td>2023-10-27 19:07:26.743772</td>\n",
       "      <td>0 days 00:00:00.142936</td>\n",
       "      <td>0.727392</td>\n",
       "      <td>11.238290</td>\n",
       "      <td>0.042310</td>\n",
       "      <td>3</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>250</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>0.958203</td>\n",
       "      <td>2023-10-27 19:07:26.743911</td>\n",
       "      <td>2023-10-27 19:07:27.098564</td>\n",
       "      <td>0 days 00:00:00.354653</td>\n",
       "      <td>0.510622</td>\n",
       "      <td>5.068645</td>\n",
       "      <td>0.016495</td>\n",
       "      <td>4</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>700</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>0.967004</td>\n",
       "      <td>2023-10-27 19:07:27.098709</td>\n",
       "      <td>2023-10-27 19:07:27.603019</td>\n",
       "      <td>0 days 00:00:00.504310</td>\n",
       "      <td>0.674292</td>\n",
       "      <td>0.122961</td>\n",
       "      <td>0.011985</td>\n",
       "      <td>10</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>550</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>0.975790</td>\n",
       "      <td>2023-10-27 19:07:27.603178</td>\n",
       "      <td>2023-10-27 19:07:27.757412</td>\n",
       "      <td>0 days 00:00:00.154234</td>\n",
       "      <td>0.547927</td>\n",
       "      <td>0.035147</td>\n",
       "      <td>0.078895</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>250</td>\n",
       "      <td>COMPLETE</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   number     value             datetime_start          datetime_complete  \\\n",
       "0       0  0.938451 2023-10-27 19:07:24.677858 2023-10-27 19:07:25.159863   \n",
       "1       1  0.966989 2023-10-27 19:07:25.160009 2023-10-27 19:07:25.289743   \n",
       "2       2  0.971390 2023-10-27 19:07:25.289875 2023-10-27 19:07:25.483036   \n",
       "3       3  0.958188 2023-10-27 19:07:25.483285 2023-10-27 19:07:26.049349   \n",
       "4       4  0.958203 2023-10-27 19:07:26.049490 2023-10-27 19:07:26.284952   \n",
       "5       5  0.969197 2023-10-27 19:07:26.285088 2023-10-27 19:07:26.600701   \n",
       "6       6  0.947223 2023-10-27 19:07:26.600836 2023-10-27 19:07:26.743772   \n",
       "7       7  0.958203 2023-10-27 19:07:26.743911 2023-10-27 19:07:27.098564   \n",
       "8       8  0.967004 2023-10-27 19:07:27.098709 2023-10-27 19:07:27.603019   \n",
       "9       9  0.975790 2023-10-27 19:07:27.603178 2023-10-27 19:07:27.757412   \n",
       "\n",
       "                duration  params_colsample_bytree  params_gamma  \\\n",
       "0 0 days 00:00:00.482005                 0.662668     18.162984   \n",
       "1 0 days 00:00:00.129734                 0.781939      0.281130   \n",
       "2 0 days 00:00:00.193161                 0.752015      0.047018   \n",
       "3 0 days 00:00:00.566064                 0.952319      0.357577   \n",
       "4 0 days 00:00:00.235462                 0.566256      2.996323   \n",
       "5 0 days 00:00:00.315613                 0.538976      0.066454   \n",
       "6 0 days 00:00:00.142936                 0.727392     11.238290   \n",
       "7 0 days 00:00:00.354653                 0.510622      5.068645   \n",
       "8 0 days 00:00:00.504310                 0.674292      0.122961   \n",
       "9 0 days 00:00:00.154234                 0.547927      0.035147   \n",
       "\n",
       "   params_learning_rate  params_max_delta_step  params_max_depth  \\\n",
       "0              0.046331                      6                10   \n",
       "1              0.084685                      8                 3   \n",
       "2              0.090246                      3                 7   \n",
       "3              0.014098                      0                 3   \n",
       "4              0.026967                      4                 7   \n",
       "5              0.032400                      1                10   \n",
       "6              0.042310                      3                 6   \n",
       "7              0.016495                      4                10   \n",
       "8              0.011985                     10                 3   \n",
       "9              0.078895                      5                 6   \n",
       "\n",
       "   params_min_child_weight  params_n_estimators     state  \n",
       "0                        1                  700  COMPLETE  \n",
       "1                        1                  250  COMPLETE  \n",
       "2                        1                  700  COMPLETE  \n",
       "3                        7                  550  COMPLETE  \n",
       "4                        1                  400  COMPLETE  \n",
       "5                        4                  700  COMPLETE  \n",
       "6                        2                  250  COMPLETE  \n",
       "7                        1                  700  COMPLETE  \n",
       "8                        4                  550  COMPLETE  \n",
       "9                        2                  250  COMPLETE  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"study_results_df\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.956140350877193"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"test_score\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_estimators': 250,\n",
       " 'learning_rate': 0.07889492017880499,\n",
       " 'max_depth': 6,\n",
       " 'gamma': 0.03514738204045661,\n",
       " 'colsample_bytree': 0.5479271132206868,\n",
       " 'min_child_weight': 2,\n",
       " 'max_delta_step': 5}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res[\"best_hyperparameters\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
